{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Forward-mode Automatic Differentiation (Beta)\n\nThis tutorial demonstrates how to use forward-mode AD to compute\ndirectional derivatives (or equivalently, Jacobian-vector products).\n\nThe tutorial below uses some APIs only available in versions >= 1.11\n(or nightly builds).\n\nAlso note that forward-mode AD is currently in beta. The API is\nsubject to change and operator coverage is still incomplete.\n\n## Basic Usage\nUnlike reverse-mode AD, forward-mode AD computes gradients eagerly\nalongside the forward pass. We can use forward-mode AD to compute a\ndirectional derivative by performing the forward pass as before,\nexcept we first associate our input with another tensor representing\nthe direction of the directional derivative (or equivalently, the ``v``\nin a Jacobian-vector product). When an input, which we call \"primal\", is\nassociated with a \"direction\" tensor, which we call \"tangent\", the\nresultant new tensor object is called a \"dual tensor\" for its connection\nto dual numbers[0].\n\nAs the forward pass is performed, if any input tensors are dual tensors,\nextra computation is performed to propogate this \"sensitivity\" of the\nfunction.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import torch\nimport torch.autograd.forward_ad as fwAD\n\nprimal = torch.randn(10, 10)\ntangent = torch.randn(10, 10)\n\ndef fn(x, y):\n    return x ** 2 + y ** 2\n\n# All forward AD computation must be performed in the context of\n# a ``dual_level`` context. All dual tensors created in such a context\n# will have their tangents destroyed upon exit. This is to ensure that\n# if the output or intermediate results of this computation are reused\n# in a future forward AD computation, their tangents (which are associated\n# with this computation) won't be confused with tangents from the later\n# computation.\nwith fwAD.dual_level():\n    # To create a dual tensor we associate a tensor, which we call the\n    # primal with another tensor of the same size, which we call the tangent.\n    # If the layout of the tangent is different from that of the primal,\n    # The values of the tangent are copied into a new tensor with the same\n    # metadata as the primal. Otherwise, the tangent itself is used as-is.\n    #\n    # It is also important to note that the dual tensor created by\n    # ``make_dual`` is a view of the primal.\n    dual_input = fwAD.make_dual(primal, tangent)\n    assert fwAD.unpack_dual(dual_input).tangent is tangent\n\n    # To demonstrate the case where the copy of the tangent happens,\n    # we pass in a tangent with a layout different from that of the primal\n    dual_input_alt = fwAD.make_dual(primal, tangent.T)\n    assert fwAD.unpack_dual(dual_input_alt).tangent is not tangent\n\n    # Tensors that do not not have an associated tangent are automatically\n    # considered to have a zero-filled tangent of the same shape.\n    plain_tensor = torch.randn(10, 10)\n    dual_output = fn(dual_input, plain_tensor)\n\n    # Unpacking the dual returns a namedtuple with ``primal`` and ``tangent``\n    # as attributes\n    jvp = fwAD.unpack_dual(dual_output).tangent\n\nassert fwAD.unpack_dual(dual_output).tangent is None"
      ]
    },
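    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For ``x ** 2 + y ** 2`` with a tangent ``v`` attached only to ``x``, the\nJacobian-vector product has the closed form ``2 * x * v``, which gives a\nquick way to sanity-check the forward-mode result. The cell below is a\nminimal sketch of that check; the names ``square_sum``, ``x``, ``v``, ``y``,\n``check_out`` and ``forward_jvp`` are introduced here for illustration only.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import torch\nimport torch.autograd.forward_ad as fwAD\n\n# Illustrative copy of ``fn`` so that this cell runs on its own\ndef square_sum(x, y):\n    return x ** 2 + y ** 2\n\nx = torch.randn(10, 10)\nv = torch.randn(10, 10)  # tangent (direction) attached to x\ny = torch.randn(10, 10)  # plain tensor, treated as having a zero tangent\n\nwith fwAD.dual_level():\n    check_out = square_sum(fwAD.make_dual(x, v), y)\n    # Read the tangent while still inside the ``dual_level`` context\n    forward_jvp = fwAD.unpack_dual(check_out).tangent\n\n# The derivative of x ** 2 + y ** 2 with respect to x, applied to v\nassert torch.allclose(forward_jvp, 2 * x * v)"
      ]
    },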
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Usage with Modules\nTo use ``nn.Module`` with forward AD, replace the parameters of your\nmodel with dual tensors before performing the forward pass. At the\ntime of writing, it is not possible to create dual tensor\n`nn.Parameter`s. As a workaround, one must register the dual tensor\nas a non-parameter attribute of the module.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import torch.nn as nn\n\nmodel = nn.Linear(5, 5)\ninput = torch.randn(16, 5)\n\nparams = {name: p for name, p in model.named_parameters()}\ntangents = {name: torch.rand_like(p) for name, p in params.items()}\n\nwith fwAD.dual_level():\n    for name, p in params.items():\n        delattr(model, name)\n        setattr(model, name, fwAD.make_dual(p, tangents[name]))\n\n    out = model(input)\n    jvp = fwAD.unpack_dual(out).tangent"
      ]
    },
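    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For a linear layer computing ``x @ W.T + b``, the JVP with respect to the\nparameters in the direction ``(dW, db)`` works out to ``x @ dW.T + db``.\nThe cell below is a minimal sketch that checks the forward-mode result\nagainst this expression. It uses its own module and hypothetical names\n(``layer``, ``x``, ``dW``, ``db``, ``linear_jvp``) so that it does not\ninterfere with the ``model``, ``params`` and ``tangents`` defined above.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import torch\nimport torch.nn as nn\nimport torch.autograd.forward_ad as fwAD\n\n# Hypothetical names used only for this check\nlayer = nn.Linear(5, 5)\nx = torch.randn(16, 5)\ndW = torch.rand_like(layer.weight)  # tangent for the weight\ndb = torch.rand_like(layer.bias)    # tangent for the bias\n\nlayer_params = dict(layer.named_parameters())\nlayer_tangents = dict(weight=dW, bias=db)\n\nwith fwAD.dual_level():\n    for name, p in layer_params.items():\n        # Same workaround as above: swap each parameter for a dual tensor\n        delattr(layer, name)\n        setattr(layer, name, fwAD.make_dual(p, layer_tangents[name]))\n    linear_jvp = fwAD.unpack_dual(layer(x)).tangent\n\n# Differentiating x @ W.T + b in the direction (dW, db)\nassert torch.allclose(linear_jvp, x @ dW.T + db, atol=1e-6)"
      ]
    },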
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Using Modules stateless API (experimental)\nAnother way to use ``nn.Module`` with forward AD is to utilize\nthe stateless API. NB: At the time of writing the stateless API is still\nexperimental and may be subject to change.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from torch.nn.utils._stateless import functional_call\n\n# We need a fresh module because the functional call requires the\n# the model to have parameters registered.\nmodel = nn.Linear(5, 5)\n\ndual_params = {}\nwith fwAD.dual_level():\n    for name, p in params.items():\n        # Using the same ``tangents`` from the above section\n        dual_params[name] = fwAD.make_dual(p, tangents[name])\n    out = functional_call(model, dual_params, input)\n    jvp2 = fwAD.unpack_dual(out).tangent\n\n# Check our results\nassert torch.allclose(jvp, jvp2)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Custom autograd Function\nCustom Functions also support forward-mode AD. To create custom Function\nsupporting forward-mode AD, register the ``jvp()`` static method. It is\npossible, but not mandatory for custom Functions to support both forward\nand backward AD. See the\n[documentation](https://pytorch.org/docs/master/notes/extending.html#forward-mode-ad)\nfor more information.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "class Fn(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, foo):\n        result = torch.exp(foo)\n        # Tensors stored in ctx can be used in the subsequent forward grad\n        # computation.\n        ctx.result = result\n        return result\n\n    @staticmethod\n    def jvp(ctx, gI):\n        gO = gI * ctx.result\n        # If the tensor stored in ctx will not also be used in the backward pass,\n        # one can manually free it using ``del``\n        del ctx.result\n        return gO\n\nfn = Fn.apply\n\nprimal = torch.randn(10, 10, dtype=torch.double, requires_grad=True)\ntangent = torch.randn(10, 10)\n\nwith fwAD.dual_level():\n    dual_input = fwAD.make_dual(primal, tangent)\n    dual_output = fn(dual_input)\n    jvp = fwAD.unpack_dual(dual_output).tangent\n\n# It is important to use ``autograd.gradcheck`` to verify that your\n# custom autograd Function computes the gradients correctly. By default,\n# gradcheck only checks the backward-mode (reverse-mode) AD gradients. Specify\n# ``check_forward_ad=True`` to also check forward grads. If you did not\n# implement the backward formula for your function, you can also tell gradcheck\n# to skip the tests that require backward-mode AD by specifying\n# ``check_backward_ad=False``, ``check_undefined_grad=False``, and\n# ``check_batched_grad=False``.\ntorch.autograd.gradcheck(Fn.apply, (primal,), check_forward_ad=True,\n                         check_backward_ad=False, check_undefined_grad=False,\n                         check_batched_grad=False)"
      ]
    },
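    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Because the derivative of ``exp`` is ``exp`` itself, the JVP of an\nexponential Function at ``p`` in the direction ``t`` should equal\n``t * exp(p)``. The cell below is a minimal sketch of that check. The class\n``ExpSketch`` and the names ``p``, ``t``, ``exp_out`` and ``exp_jvp`` are\nhypothetical stand-ins introduced for illustration, and no ``backward()`` is\ndefined because only forward-mode AD is exercised here.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import torch\nimport torch.autograd.forward_ad as fwAD\n\n# Hypothetical stand-in for ``Fn`` so that this cell runs on its own\nclass ExpSketch(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, inp):\n        result = torch.exp(inp)\n        # Keep the output around for the forward-mode formula\n        ctx.result = result\n        return result\n\n    @staticmethod\n    def jvp(ctx, grad_in):\n        return grad_in * ctx.result\n\np = torch.randn(10, 10, dtype=torch.double)\nt = torch.randn(10, 10, dtype=torch.double)\n\nwith fwAD.dual_level():\n    exp_out = ExpSketch.apply(fwAD.make_dual(p, t))\n    exp_jvp = fwAD.unpack_dual(exp_out).tangent\n\n# Analytic JVP of exp at p in the direction t\nassert torch.allclose(exp_jvp, t * torch.exp(p))"
      ]
    },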
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Functional API (beta)\nWe also offer a higher-level functional API in functorch\nfor computing Jacobian-vector products that you may find simpler to use\ndepending on your use case.\n\nThe benefit of the functional API is that there isn't a need to understand\nor use the lower-level dual tensor API and that you can compose it with\nother [functorch transforms (like vmap)](https://pytorch.org/functorch/stable/notebooks/jacobians_hessians.html);\nthe downside is that it offers you less control.\n\nNote that the remainder of this tutorial will require functorch\n(https://github.com/pytorch/functorch) to run. Please find installation\ninstructions at the specified link.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import functorch as ft\n\nprimal0 = torch.randn(10, 10)\ntangent0 = torch.randn(10, 10)\nprimal1 = torch.randn(10, 10)\ntangent1 = torch.randn(10, 10)\n\ndef fn(x, y):\n    return x ** 2 + y ** 2\n\n# Here is a basic example to compute the JVP of the above function.\n# The jvp(func, primals, tangents) returns func(*primals) as well as the\n# computed jvp. Each primal must be associated with a tangent of the same shape.\nprimal_out, tangent_out = ft.jvp(fn, (primal0, primal1), (tangent0, tangent1))\n\n# functorch.jvp requires every primal to be associated with a tangent.\n# If we only want to associate certain inputs to `fn` with tangents,\n# then we'll need to create a new function that captures inputs without tangents:\nprimal = torch.randn(10, 10)\ntangent = torch.randn(10, 10)\ny = torch.randn(10, 10)\n\nimport functools\nnew_fn = functools.partial(fn, y=y)\nprimal_out, tangent_out = ft.jvp(new_fn, (primal,), (tangent,))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Using the functional API with Modules\nTo use ``nn.Module`` with functorch.jvp to compute Jacobian-vector products\nwith respect to the model parameters, we need to reformulate the\n``nn.Module`` as a function that accepts both the model parameters and inputs\nto the module.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "model = nn.Linear(5, 5)\ninput = torch.randn(16, 5)\ntangents = tuple([torch.rand_like(p) for p in model.parameters()])\n\n# Given a torch.nn.Module, ft.make_functional_with_buffers extracts the state\n# (params and buffers) and returns a functional version of the model that\n# can be invoked like a function.\n# That is, the returned ``func`` can be invoked like\n# ``func(params, buffers, input)``.\n# ft.make_functional_with_buffers is analogous to the nn.Modules stateless API\n# that you saw previously and we're working on consolidating the two.\nfunc, params, buffers = ft.make_functional_with_buffers(model)\n\n# Because jvp requires every input to be associated with a tangent, we need to\n# create a new function that, when given the parameters, produces the output\ndef func_params_only(params):\n    return func(params, buffers, input)\n\nmodel_output, jvp_out = ft.jvp(func_params_only, (params,), (tangents,))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "[0] https://en.wikipedia.org/wiki/Dual_number\n\n"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.4"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}